from iclr23code import Neuron
import torch
from torch.nn.utils import clip_grad_norm_
from iclr23code.utils import unfold, cal_gap, cal_k, back_cal_gap
from math import sqrt


# ========================================= Adjust ==========================================
class Adjust(object):
    """
    Base scheduler that drives epoch-gated adjustments of ``Neuron`` layers.

    Subclasses override :meth:`action`; :meth:`step` invokes it only once the
    epoch index has reached ``limit``.
    """

    def __init__(self, limit=0):
        """
        :param limit: First epoch index at which :meth:`action` is run; must be >= 0.
        :raises ValueError: if ``limit`` is negative.
        """
        if limit < 0:
            raise ValueError("`limit` can't smaller than 0.")
        self.limit = limit

    @staticmethod
    def start():
        """Open both ``Neuron`` gates: ``open_gate`` then ``open_count_gate``."""
        Neuron.open_gate()
        Neuron.open_count_gate()

    @staticmethod
    def close():
        """Close both ``Neuron`` gates: ``close_gate`` then ``close_count_gate``."""
        Neuron.close_gate()
        Neuron.close_count_gate()

    @staticmethod
    def reset():
        """Call ``reset_count`` and ``reset_num`` on every ``Neuron.instance``, then close both gates."""
        for neuron in Neuron.instance:
            neuron.reset_count()
            neuron.reset_num()
        Neuron.close_gate()
        Neuron.close_count_gate()

    def action(self):
        """Adjustment hook run by :meth:`step`; the base implementation does nothing."""

    def step(self, ep):
        """Run :meth:`action` when ``ep`` is not below ``limit``; earlier epochs are no-ops."""
        if not ep < self.limit:
            self.action()


class NewAdjust(Adjust):
    """
    Adaptively rescale the ``slope`` of selected ``Neuron`` layers from their statistics.

    On each :meth:`action` call, every selected layer's ``mean``, ``var``,
    ``threshold`` and current ``slope`` are fed to ``k_method`` (``cal_k``) to obtain
    a ratio ``k``. ``gap_method`` (``cal_gap``) then derives a new gap whose
    reciprocal (:meth:`get_slope`) is blended into ``slope`` with ``momentum``.

    Must be constructed after the model, because the layer index lists are built
    from ``Neuron.instance``.
    """

    def __init__(self, limit=0, key='0', momentum=0.0, limit_k=0.2, epi=1e-8):
        """
        :param limit: First epoch at which :meth:`action` runs (see ``Adjust.step``).
        :param key: Selects the index list from :meth:`generate_dict`
            (``'0'``: every layer, ``'1'``: layer 0 only).
        :param momentum: Weight kept on the old slope when blending; in [0, 1].
        :param limit_k: Floor applied to each layer's tracked minimum ``k``; in (0, 1).
        :param epi: Non-negative tolerance. Layers whose std is <= ``epi`` are skipped,
            and the value is also forwarded to ``gap_method``.
        :raises ValueError: on out-of-range arguments or an empty index list.
        :raises KeyError: if ``key`` is not a key of :meth:`generate_dict`.
        """
        super(NewAdjust, self).__init__(limit)
        if momentum < 0.0 or momentum > 1.0:
            raise ValueError("`momentum` must in [0, 1].")
        if limit_k <= 0.0 or limit_k >= 1.0:
            raise ValueError("`limit_k` must in (0, 1).")
        if epi < 0:
            raise ValueError("epi not smaller than 0.")

        self.momentum = momentum
        self.limit_k = limit_k
        self.epi = epi
        # Need to be initialized after model initialization: the index lists and
        # the per-layer buffers below are sized by ``Neuron.instance``.
        self.generate_dict()
        self.idx_list = self.inner_dict[key]
        if len(self.idx_list) == 0:
            raise ValueError("If not set update list, not use `NewAdjust`")

        n_layers = len(Neuron.instance)
        # Slope recorded the first time action() processes a layer; 0 means "not seen yet".
        self.ori_slope = torch.zeros(n_layers)
        self.k_method = cal_k
        self.gap_method = cal_gap

        # Running minimum of k per layer (floored at ``limit_k`` in update_k).
        self.min_k = torch.zeros(n_layers)
        # Set once a layer's k exceeds ``limit_k``; only flagged layers get their slope
        # updated (original note: distinguish zero_*_block's neuron).
        self.flag = [False for _ in range(n_layers)]

    def generate_dict(self):
        """Build ``inner_dict``: ``'0'`` -> every index of ``Neuron.instance``, ``'1'`` -> ``[0]``."""
        self.inner_dict = {
            '0': list(range(len(Neuron.instance))),
            '1': [0],
        }

    def get_list(self):
        """Return the layer indices that :meth:`action` processes."""
        return self.idx_list

    def update_k(self, temp_k, idx):
        """
        Update the tracked minimum ``k`` of layer ``idx`` and return the ``k`` to use.

        :param temp_k: Not torch.Tensor
        :return: A Python float. If the stored minimum is below ``limit_k``, it is
            raised to ``limit_k`` and returned. Otherwise the stored minimum is
            returned when it is smaller than ``temp_k``; if not, ``temp_k`` becomes the
            new minimum and is returned.
        """
        if self.min_k[idx] < self.limit_k:
            self.min_k[idx] = self.limit_k
            return self.min_k[idx].item()
        if self.min_k[idx] < temp_k:
            return self.min_k[idx].item()
        # temp_k may be smaller than self.limit_k. Per the original author's note,
        # cal_gap does not change the slope in this iteration, and in the next
        # iteration the stored value will be compared with self.limit_k.
        self.min_k[idx] = temp_k
        return temp_k

    def get_slope(self, new_gap, idx, device):
        """Return ``1 / new_gap`` as a float tensor on ``device`` (``idx`` is unused here)."""
        return torch.tensor(1 / new_gap, dtype=torch.float, device=device)

    def action(self):
        """
        Recompute and blend in a new ``slope`` for every layer in :meth:`get_list`.

        Layers whose std is <= ``epi`` are skipped. The first time a layer is
        processed, only its slope and ``k`` are recorded. After that, ``k`` is updated
        via :meth:`update_k`, and once the layer is flagged its ``slope`` becomes
        ``momentum * slope + (1 - momentum) * get_slope(new_gap)``.

        :raises ValueError: if a layer reports a negative variance.
        """
        for idx in self.get_list():
            layer = Neuron.instance[idx]
            device = layer.slope.device
            threshold = layer.threshold.item()
            slope = layer.slope.item()
            last_gap = 1 / slope
            mean = layer.mean.item()
            var = layer.var.item()

            # Validate before sqrt. math.sqrt raises an opaque "math domain error" on
            # negative input and never returns a negative value, so a sign check on
            # sigma after the sqrt could never fire.
            if var < 0:
                raise ValueError(f"Something went wrong: variance is negative (var={var}).")
            sigma = sqrt(var)
            if sigma <= self.epi:
                continue

            temp_k = self.k_method(mean, sigma, last_gap, threshold)
            if temp_k > self.limit_k:
                self.flag[idx] = True

            if self.ori_slope[idx] == 0:
                # First visit: remember the slope and seed the running minimum k.
                self.ori_slope[idx] = slope
                self.min_k[idx] = temp_k
                continue

            cur_k = self.update_k(temp_k, idx)  # Return value not a tensor
            if not self.flag[idx]:
                continue
            new_gap = self.gap_method(mean, sigma, cur_k, last_gap, threshold, self.epi)
            new_slope = self.get_slope(new_gap, idx, device)
            layer.slope = self.momentum * layer.slope + (1 - self.momentum) * new_slope


class BNAdjust(NewAdjust):
    """``NewAdjust`` variant that computes the new gap with ``back_cal_gap`` instead of ``cal_gap``."""

    def __init__(self, limit=0, key='0', momentum=0.0, limit_k=0.2, epi=1e-8):
        """
        All arguments are forwarded to ``NewAdjust``.

        ``epi`` was previously not exposed here, which silently pinned it to
        ``NewAdjust``'s default. It is now configurable, and its default value is
        unchanged.
        """
        super(BNAdjust, self).__init__(limit, key, momentum, limit_k, epi)
        self.gap_method = back_cal_gap


# ============================================================================================
def store_info(batch_num, num_step, epoch, fire_recorder, slope_recorder, mean_recorder,
               var_recorder):
    """
    Record per-layer statistics of every ``Neuron`` for one epoch and print fire rates.

    Column ``col`` of each recorder's row ``epoch`` receives the values of
    ``Neuron.instance[col]``. The fire rate is ``fire_num`` divided by
    ``output_dim``, ``num_step`` and ``batch_num``.

    :param batch_num: batch_size
    :param num_step: used t
    :param epoch: current epoch
    :param fire_recorder: [E, N]
    :param slope_recorder: [E, N]
    :param mean_recorder: [E, N]
    :param var_recorder: [E, N]
    """
    for col, neuron in enumerate(Neuron.instance):
        fire_rate = (neuron.fire_num / neuron.output_dim / num_step / batch_num).item()
        fire_recorder[epoch, col] = fire_rate
        slope_recorder[epoch, col] = neuron.slope.item()
        mean_recorder[epoch, col] = neuron.mean.item()
        var_recorder[epoch, col] = neuron.var.item()
        print(f"Fire rate: {fire_rate}, shape: {neuron.output_dim.item()}")


def train(net, num_step, data, target, optimizer, criterion, pre_process=None, max_norm=1.0):
    """
    Run one optimisation step on the time-summed network output.

    Data processed (fold and set device) before call train.
    :param net: Network whose output is unfolded into ``num_step`` time slices
    :param num_step: Time dimension
    :param data: [T * B, C, H, W]
    :param target: Labels compared with the arg-max (dim 1) of the summed output
    :param optimizer: Optimizer; ``zero_grad`` and ``step`` are called once
    :param criterion: Loss applied to the (optionally pre-processed) summed output
    :param pre_process: Some process for result
    :param max_norm: Gradient-norm clipping threshold passed to ``clip_grad_norm_``;
        ``None`` disables clipping. The default of 1 matches the former hard-coded value.
    :return: loss and correct, type is not torch.Tensor
    :raises TypeError: if ``pre_process`` is given but not callable
    """
    # Validate before touching layer state or running the forward pass.
    if pre_process is not None and not callable(pre_process):
        raise TypeError("`pre_process` must be callable.")

    layer_init_()
    mem_total = 0
    res = unfold(net(data), num_step)  # [T, B, N]
    for step in range(num_step):
        mem_total += res[step]
    predicted = mem_total.detach().max(dim=1).indices
    correct = (predicted == target).sum().item()

    if pre_process is not None:
        mem_total = pre_process(mem_total)
    loss = criterion(mem_total, target)
    optimizer.zero_grad()
    loss.backward()
    if max_norm is not None:
        clip_grad_norm_(net.parameters(), max_norm)
    optimizer.step()
    layer_detach_()

    return loss.item(), correct


def trainImg(net, num_step, data, target, optimizer, criterion, pre_process=None):
    """
    Run one optimisation step on the time-summed output, without gradient clipping.

    Data processed (fold and set device) before call train.
    :param num_step: Time dimension
    :param data: [T * B, C, H, W]
    :param pre_process: Some process for result
    :return: loss and correct, type is not torch.Tensor
    """
    layer_init_()
    outputs = unfold(net(data), num_step)  # [T, B, N]
    summed = sum(outputs[t] for t in range(num_step))
    hits = summed.detach().max(dim=1).indices == target
    correct = hits.sum().item()

    if pre_process is None:
        logits = summed
    elif callable(pre_process):
        logits = pre_process(summed)
    else:
        raise TypeError("`pre_process` must be callable.")

    loss = criterion(logits, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    layer_detach_()
    return loss.item(), correct

def train_en(net, num_step, data, target, optimizer, criterion, pre_process=None, lamb=0.001, means=1):
    """
    Run one optimisation step with a per-time-step loss plus an MSE regulariser.

    ``criterion`` is applied to every time slice of the output (optionally passed
    through ``pre_process`` first) and averaged over ``num_step``. That term is then
    blended with an MSE between the raw outputs and the constant ``means``:
    ``loss = (1 - lamb) * task_loss + lamb * mse``.

    Data processed (fold and set device) before call train.
    :param num_step: Time dimension
    :param pre_process: Optional callable applied to each slice before ``criterion``
        (not applied to the MSE term or to the prediction)
    :param lamb: Weight of the MSE term; 0 disables it
    :param means: Constant target value of the MSE term
    :return: loss and correct, type is not torch.Tensor
    :raises TypeError: if ``pre_process`` is given but not callable
    """
    # Validate once, before touching layer state or running the forward pass.
    if pre_process is not None and not callable(pre_process):
        raise TypeError("`pre_process` must be callable.")

    layer_init_()
    mem_total = 0
    loss = 0
    res = unfold(net(data), num_step)

    for step in range(num_step):
        out = res[step]
        mem_total += out.detach()  # predictions use the raw, detached outputs
        if pre_process is not None:
            out = pre_process(out)
        loss += criterion(out, target)

    loss /= num_step  # Can add regulation

    if lamb != 0:
        mse_loss = torch.nn.MSELoss()
        mse = mse_loss(res, torch.full_like(res, means))
    else:
        mse = 0

    loss = (1 - lamb) * loss + lamb * mse  # L_Total
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    layer_detach_()

    predicted = mem_total.max(dim=1).indices
    correct = (predicted == target).sum().item()
    return loss.item(), correct

def test(net, num_step, data, target):
    """
    Count correct predictions for one batch.

    The network output is unfolded into ``num_step`` slices. The detached slices
    are summed, and their arg-max over dim 1 is compared with ``target``.

    :return: Not torch.Tensor
    """
    layer_init_()
    outputs = unfold(net(data), num_step)
    summed = sum(outputs[t].detach() for t in range(num_step))
    hits = summed.max(dim=1).indices == target
    return hits.sum().item()


def layer_init_():
    """Call ``detach_param`` and then ``init_param`` on every layer in ``Neuron.instance``."""
    for neuron in Neuron.instance:
        neuron.detach_param()
        neuron.init_param()


def layer_detach_():
    """Call ``detach_param`` on every layer in ``Neuron.instance``."""
    for neuron in Neuron.instance:
        neuron.detach_param()
